#!/usr/bin/env python
# -*- coding: utf-8; py-indent-offset:4 -*-
###############################################################################
#
# Copyright (C) 2017 Ed Bartosh <bartosh@gmail.com>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program.  If not, see <http://www.gnu.org/licenses/>.
#
###############################################################################
from __future__ import (absolute_import, division, print_function,
                        unicode_literals)

import time
from datetime import datetime
from functools import wraps

import backtrader as bt
import ccxt
from backtrader.metabase import MetaParams
from backtrader.utils.py3 import with_metaclass
from ccxt.base.errors import NetworkError, ExchangeError, OrderNotFound
from backtrader.utils import logger


class MetaSingleton(MetaParams):
    '''Metaclass that turns every class built with it into a singleton.

    The first call to the class creates the instance and caches it on the
    class itself; every later call returns that cached instance and ignores
    the arguments it was given.
    '''

    def __init__(cls, name, bases, dct):
        super(MetaSingleton, cls).__init__(name, bases, dct)
        # Every class produced by this metaclass starts with an empty cache.
        cls._singleton = None

    def __call__(cls, *args, **kwargs):
        instance = cls._singleton
        if instance is None:
            instance = super(MetaSingleton, cls).__call__(*args, **kwargs)
            cls._singleton = instance
        return instance

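# NOTE: when support for a new exchange is added, its store class must be
# named following the same pattern as Okex5Store below.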
class Okex5Store(with_metaclass(MetaSingleton, object)):
    '''API provider for CCXT feed and broker classes (OKX v5).

    Besides the unified CCXT calls, this store exposes:

    - ``get_wallet`` to fetch balances manually; extra ``params`` are passed
      through to ``fetch_balance`` (useful for margin balances).
    - ``private_end_point`` to call any private, non-unified (implicit)
      method of the underlying ccxt exchange object.
    - ``create_strat_order`` / ``cancel_strat_order`` / ``fetch_strat_order``
      / ``fetch_open_strat_orders`` / ``search_strat_order`` for OKX algo
      ("strategy") orders: stop-loss / take-profit, conditional and OCO.

    Because the class uses ``MetaSingleton``, only one instance is ever
    created; later constructions return it and ignore their arguments.
    '''

    # Supported granularities
    _GRANULARITIES = {
        (bt.TimeFrame.Minutes, 1): '1m',
        (bt.TimeFrame.Minutes, 3): '3m',
        (bt.TimeFrame.Minutes, 5): '5m',
        (bt.TimeFrame.Minutes, 15): '15m',
        (bt.TimeFrame.Minutes, 30): '30m',
        (bt.TimeFrame.Minutes, 60): '1h',
        (bt.TimeFrame.Minutes, 90): '90m',
        (bt.TimeFrame.Minutes, 120): '2h',
        (bt.TimeFrame.Minutes, 180): '3h',
        (bt.TimeFrame.Minutes, 240): '4h',
        (bt.TimeFrame.Minutes, 360): '6h',
        (bt.TimeFrame.Minutes, 480): '8h',
        (bt.TimeFrame.Minutes, 720): '12h',
        (bt.TimeFrame.Days, 1): '1d',
        (bt.TimeFrame.Days, 3): '3d',
        (bt.TimeFrame.Weeks, 1): '1w',
        (bt.TimeFrame.Weeks, 2): '2w',
        (bt.TimeFrame.Months, 1): '1M',
        (bt.TimeFrame.Months, 3): '3M',
        (bt.TimeFrame.Months, 6): '6M',
        (bt.TimeFrame.Years, 1): '1y',
    }

    log = logger.getLogger(__name__)

    BrokerCls = None  # broker class will auto register
    DataCls = None  # data class will auto register

    # Default ccxt exchange options. NOTE(review): not referenced inside this
    # class; presumably merged into ``config`` by the caller — confirm.
    exchange_config = {
        "options": {
            "createMarketBuyOrderRequiresPrice": True,
        },
        "enableRateLimit": True
    }

    # Translation tables between backtrader concepts and exchange fields:
    # order types, how closed/canceled orders are recognised, and which
    # exchange field feeds each attribute, for regular and algo ("strat") orders.
    broker_mapping = {
        'order_types': {
            bt.Order.Market: 'market',
            bt.Order.Limit: 'limit'
        },
        'mappings': {
            'closed_order': {
                'key': 'status',
                'value': 'closed'
            },
            'canceled_order': {
                'key': 'status',
                'value': 'canceled'
            },
            "order_result": {
                'id': 'id',
                'clientOrderId': 'clientOrderId',
                'symbol': 'symbol',
                'exectype': 'type',
                'side': "side",
                'price': "price",  # order price
                'stop_price': "stopPrice",  # stop-loss / take-profit trigger price
                'size': "amount",
                'order_time': "cTime",  # order time
                'trade_time': "uTime",  # trade time
                'trade_dt': "timestamp",  # trade time
                'trade_price': "average",  # fill price
            },
            'closed_strat_order': {
                'key': 'state',
                'value': 'effective'
            },
            'canceled_strat_order': {
                'key': 'state',
                'value': 'canceled'
            },
            "order_strat_result": {
                'id': 'algoId',
                'orderid': 'ordId',
                'symbol': 'instId',
                'exectype': 'ordType',
                'side': "side",
                'loss_price': "slOrdPx",  # stop-loss order price
                'trigger_loss_price': "slTriggerPx",  # stop-loss trigger price
                'profit_price': "tpOrdPx",  # take-profit order price
                'trigger_profit_price': "tpTriggerPx",  # take-profit trigger price
                'size': "sz",
                'trade_dt': "triggerTime",  # trade time
                'order_time': "cTime",
            },
        }
    }

    @classmethod
    def getdata(cls, *args, **kwargs):
        '''Returns ``DataCls`` with args, kwargs'''
        return cls.DataCls(*args, **kwargs)

    @classmethod
    def getbroker(cls, *args, **kwargs):
        '''Returns broker with *args, **kwargs from registered ``BrokerCls``'''
        return cls.BrokerCls(*args, **kwargs)

    def __init__(self, exchange, currency, config, retries, debug=False, sandbox=False):
        '''Create the underlying ccxt exchange client.

        :param exchange: name of a ccxt exchange class (looked up on ``ccxt``).
        :param currency: list of currency codes; ``"USDT"`` is appended.
        :param config: configuration dict passed to the ccxt exchange class.
        :param retries: number of attempts made by retried calls (at least
            one attempt is always made).
        :param debug: print each attempt / OHLCV request to stdout.
        :param sandbox: enable ccxt sandbox (testnet) mode.
        '''
        self.exchange = getattr(ccxt, exchange)(config)
        if sandbox:
            self.exchange.set_sandbox_mode(True)
        self.retries = retries
        self.debug = debug
        self.currency = currency + ["USDT"]

    def get_granularity(self, timeframe, compression):
        '''Translate a backtrader (timeframe, compression) into a ccxt timeframe string.

        :raises NotImplementedError: if the exchange cannot fetch OHLCV data.
        :raises ValueError: if the combination is unknown to this store or not
            offered by the exchange.
        '''
        if not self.exchange.has['fetchOHLCV']:
            raise NotImplementedError("'%s' exchange doesn't support fetching OHLCV data" %
                                      self.exchange.name)

        granularity = self._GRANULARITIES.get((timeframe, compression))
        if granularity is None:
            raise ValueError("backtrader CCXT module doesn't support fetching OHLCV "
                             "data for time frame %s, compression %s" %
                             (bt.TimeFrame.getname(timeframe), compression))

        if self.exchange.timeframes and granularity not in self.exchange.timeframes:
            raise ValueError("'%s' exchange doesn't support fetching OHLCV data for "
                             "%s time frame" % (self.exchange.name, granularity))

        return granularity

    def retry(method):
        '''Decorator: call ``method`` up to ``self.retries`` times.

        Each attempt first sleeps ``exchange.rateLimit`` milliseconds to stay
        within the exchange rate limit. ``NetworkError`` / ``ExchangeError``
        (including their ccxt subclasses such as ``OrderNotFound``) trigger
        another attempt; the error of the last attempt is re-raised. At least
        one attempt is always made, even when ``retries`` is 0 or negative.
        '''
        @wraps(method)
        def retry_method(self, *args, **kwargs):
            attempts = max(1, self.retries)
            for i in range(attempts):
                if self.debug:
                    print('{} - {} - Attempt {}'.format(datetime.now(), method.__name__, i))
                time.sleep(self.exchange.rateLimit / 1000)
                try:
                    return method(self, *args, **kwargs)
                except (NetworkError, ExchangeError) as err:
                    Okex5Store.log.info('%s - %s - Attempt %s failed: %s',
                                        datetime.now(), method.__name__, i, err)
                    if i == attempts - 1:
                        raise

        return retry_method

    @staticmethod
    def _algo_ord_type(order_type):
        '''Map a backtrader algo order type to the OKX ``ordType`` (None if unknown).'''
        if order_type in (bt.Order.Stop, bt.Order.StopLimit,
                          bt.Order.Profit, bt.Order.ProfitLimit):
            return "conditional"
        if order_type in (bt.Order.StopProfit, bt.Order.StopProfitLimit):
            return "oco"
        return None

    @retry
    def get_wallet(self, params=None):
        '''Return the account balance via ccxt ``fetch_balance``.

        :param params: optional exchange-specific parameters. ``None`` is sent
            as an empty dict, since ccxt's ``fetch_balance`` expects a dict.
        '''
        return self.exchange.fetch_balance(params if params is not None else {})

    @retry
    def fetch_market(self, symbol):
        '''Return the ccxt market structure for ``symbol`` (markets are loaded first).'''
        self.exchange.load_markets()
        return self.exchange.market(symbol)

    @retry
    def create_order(self, symbol, order_type, side, amount, price, params):
        '''Create a regular (non-algo) order, e.g. on ``'ADA/USDT'``, and return it.

        NOTE(review): this call is retried on network errors, so a request that
        reached the exchange before failing could be placed twice.
        '''
        return self.exchange.create_order(symbol=symbol, type=order_type, side=side,
                                          amount=amount, price=price, params=params)

    @retry
    def create_strat_order(self, symbol, order_type, side, amount, **kwargs):
        '''Place an OKX algo ("strategy") order. Currently spot only.

        Supported ``order_type`` values:

        - ``bt.Order.Stop``: market stop-loss (``conditional``)
        - ``bt.Order.StopLimit``: limit stop-loss (``conditional``)
        - ``bt.Order.Profit``: market take-profit (``conditional``)
        - ``bt.Order.ProfitLimit``: limit take-profit (``conditional``)
        - ``bt.Order.StopProfit``: market stop-loss + take-profit (``oco``)
        - ``bt.Order.StopProfitLimit``: limit stop-loss + take-profit (``oco``)

        Keyword arguments:

        - ``tdMode``: trade mode, default ``'cash'`` (non-margin);
          ``'isolated'`` or ``'cross'`` for margin.
        - ``loss_price`` + ``trigger_loss``: stop-loss order / trigger price,
          only used when both are given.
        - ``profit_price`` + ``trigger_profit``: take-profit order / trigger
          price, only used when both are given.

        Algo order lifecycle:

        - ``state == 'live'``: in the pending algo list, ``orderid == 0``
        - ``state == 'effective'``: in the algo history with an ``orderid``;
          a regular order exists in the normal order list
        - ``state == 'canceled'``: in the algo history, ``orderid == 0``

        Returns the first ``data`` entry of the response, e.g.
        ``{'algoId': '388788059810201610', 'sCode': '0', 'sMsg': ''}``.

        :raises ValueError: if ``order_type`` is not one of the types above.
        '''
        td_mode = kwargs.get("tdMode", "cash")

        loss_price = trigger_loss = profit_price = trigger_profit = None
        if "loss_price" in kwargs and "trigger_loss" in kwargs:
            loss_price, trigger_loss = kwargs["loss_price"], kwargs["trigger_loss"]
        if "profit_price" in kwargs and "trigger_profit" in kwargs:
            profit_price, trigger_profit = kwargs["profit_price"], kwargs["trigger_profit"]

        # An order price of -1 makes that leg execute at market price.
        market_sl = {"slTriggerPx": trigger_loss, "slOrdPx": -1}
        limit_sl = {"slTriggerPx": trigger_loss, "slOrdPx": loss_price}
        market_tp = {"tpTriggerPx": trigger_profit, "tpOrdPx": -1}
        limit_tp = {"tpTriggerPx": trigger_profit, "tpOrdPx": profit_price}

        if order_type == bt.Order.Stop:
            ord_type, legs = "conditional", (market_sl,)
        elif order_type == bt.Order.StopLimit:
            ord_type, legs = "conditional", (limit_sl,)
        elif order_type == bt.Order.Profit:
            ord_type, legs = "conditional", (market_tp,)
        elif order_type == bt.Order.ProfitLimit:
            ord_type, legs = "conditional", (limit_tp,)
        elif order_type == bt.Order.StopProfit:
            ord_type, legs = "oco", (market_sl, market_tp)
        elif order_type == bt.Order.StopProfitLimit:
            ord_type, legs = "oco", (limit_sl, limit_tp)
        else:
            # Fail fast instead of sending (and retrying) a malformed request.
            raise ValueError("Unsupported algo order type: %s" % (order_type,))

        params = {
            "instId": symbol.replace('/', "-"),
            "tdMode": td_mode,
            "side": side,
            "sz": amount,
            "ordType": ord_type,
        }
        for leg in legs:
            params.update(leg)
        return self.exchange.private_post_trade_order_algo(params=params)['data'][0]

    @retry
    def cancel_order(self, order_id, symbol):
        '''Cancel a regular order through ccxt and return the exchange response.'''
        return self.exchange.cancel_order(order_id, symbol)

    @retry
    def cancel_strat_order(self, symbol, algoid, **kwargs):
        '''Cancel an algo order and return its latest state.

        :param order_type: optional keyword, the backtrader algo order type,
            used to look the order up afterwards (``conditional`` vs ``oco``).
            Defaults to ``bt.Order.StopLimit`` (conditional orders).

        If the cancel succeeds, or the order no longer exists, the order is
        fetched via ``fetch_strat_order``. If the cancel is rejected, the
        pending entry is returned when there is one; otherwise it falls back
        to ``fetch_strat_order``.
        '''
        order_type = kwargs.get("order_type", bt.Order.StopLimit)
        request = [
            {
                "algoId": algoid,
                "instId": symbol.replace('/', "-")
            }
        ]
        try:
            # e.g. {'code': '0', 'data': [{'algoId': '...', 'sCode': '0', 'sMsg': ''}], 'msg': ''}
            result = self.exchange.private_post_trade_cancel_algos(params=request)["data"]
        except OrderNotFound:
            # Already gone: read it from the algo history.
            return self.fetch_strat_order(symbol, order_type, algoid)

        if result and result[0]["sCode"] == "0":
            # Cancelled: read the final state.
            return self.fetch_strat_order(symbol, order_type, algoid)

        # Cancel rejected: prefer the still-pending entry.
        pending = self.fetch_open_strat_orders(symbol, order_type, algoId=algoid)
        if pending:
            return pending[0]
        return self.fetch_strat_order(symbol, order_type, algoid)

    @retry
    def fetch_ohlcv(self, symbol, timeframe, since, limit, params=None):
        '''Fetch OHLCV candles through ccxt.

        :param params: optional exchange-specific parameters (default: none).
        '''
        if params is None:
            params = {}
        if self.debug:
            print('Fetching: {}, TF: {}, Since: {}, Limit: {}'.format(symbol, timeframe, since, limit))
        return self.exchange.fetch_ohlcv(symbol, timeframe=timeframe, since=since, limit=limit, params=params)

    @retry
    def fetch_order(self, oid, symbol):
        '''Return the details of a regular order.'''
        return self.exchange.fetch_order(oid, symbol)

    def fetch_strat_order(self, symbol, order_type, algoid):
        '''Return one algo order, looking in the pending list first, then history.

        If the order is not pending, it has already left the pending list, so
        the algo-order history is queried. Its ``instId`` is converted to
        ``'BASE/QUOTE'`` form.

        :raises OrderNotFound: if neither list returns the order.
        '''
        params = {}
        ord_type = self._algo_ord_type(order_type)
        if ord_type:
            params["ordType"] = ord_type

        if symbol:
            params["instId"] = symbol.replace("/", "-")

        if algoid:
            params["algoId"] = algoid

        order = self.fetch_open_strat_orders(symbol, order_type, algoId=algoid)
        if not order:
            order = self.exchange.private_get_trade_orders_algo_history(params=params)["data"]
        if not order:
            raise OrderNotFound("algo order %s not found for %s" % (algoid, symbol))
        return self.replace_symbol(order[0])

    def search_strat_order(self, symbol, orderid):
        '''Find the effective algo order that produced the regular order ``orderid``.

        Conditional algo orders are searched first, then OCO ones. Returns the
        matching entry (``instId`` converted to ``'BASE/QUOTE'``), or ``None``.
        '''
        inst_id = symbol.replace("/", "-")
        for ord_type in ("conditional", "oco"):
            params = {
                "ordType": ord_type,
                "instId": inst_id,
                "state": "effective"
            }
            orders = self.exchange.private_get_trade_orders_algo_history(params=params)["data"]
            match = next((o for o in orders or [] if o["ordId"] == orderid), None)
            if match is not None:
                return self.replace_symbol(match)
        return None

    @retry
    def fetch_open_orders(self):
        '''Return open regular buy/sell orders (not stop-loss/take-profit algo orders).'''
        return self.exchange.fetchOpenOrders()

    def replace_symbol(self, data):
        '''Convert ``instId`` from ``'BASE-QUOTE'`` to ``'BASE/QUOTE'`` in place.

        Accepts a single order dict or a list of them; falsy input is returned
        unchanged. Returns ``data``.
        '''
        if not data:
            return data

        if isinstance(data, dict):
            data["instId"] = data["instId"].replace("-", "/")
        elif isinstance(data, list):
            for d in data:
                d["instId"] = d["instId"].replace("-", "/")
        return data

    @retry
    def fetch_open_strat_orders(self, symbol, order_type, **kwargs):
        '''Return the pending (not yet triggered) algo orders as a list.

        :param order_type: backtrader algo order type; selects OKX ``ordType``.
        :param algoId: optional keyword restricting the result to one algo order.
        :raises ValueError: if ``order_type`` is falsy.

        An ``OrderNotFound`` from the exchange is treated as "nothing pending"
        and yields ``[]``.
        '''
        if not order_type:
            raise ValueError("order_type must given.")

        params = {}
        ord_type = self._algo_ord_type(order_type)
        if ord_type:
            params["ordType"] = ord_type

        if symbol:
            params["instId"] = symbol.replace("/", "-")

        if "algoId" in kwargs:
            params["algoId"] = kwargs["algoId"]

        try:
            # Entries carry state live / pause / partially_effective.
            rs = self.exchange.private_get_trade_orders_algo_pending(params=params)["data"]
        except OrderNotFound:
            # Not pending any more (triggered or canceled).
            return []
        return self.replace_symbol(rs)

    @retry
    def private_end_point(self, type, endpoint, params):
        '''
        Call any implicit (non-unified) method of the ccxt exchange object.
        See here: https://github.com/ccxt/ccxt/wiki/Manual#implicit-api-methods

        - type: unused; kept for backward compatibility with callers that
          pass 'Get', 'Post', 'Put' or 'Delete'.
        - endpoint: the name of the ccxt implicit method to call, e.g.
          'private_get_trade_orders_algo_pending'.
        - params: dict passed to that method, which sends the request to the
          exchange and returns the exchange-specific JSON result unparsed.

        To list all available methods of an exchange instance, including
        implicit and unified methods:

        print(dir(ccxt.okx()))
        '''
        return getattr(self.exchange, endpoint)(params)
